{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import math\n",
    "import torch_geometric\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.model_selection import KFold\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "from os.path import join\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import ChemicalFeatures, AllChem\n",
    "from rdkit import RDConfig\n",
    "from rdkit.Chem.rdmolfiles import MolFromMolFile\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import ast\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "import random\n",
    "import torch\n",
    "import deepdish as dd\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import Sequential, Linear, ReLU, GRU, BatchNorm1d, Dropout\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.nn import NNConv, Set2Set, GCNConv, global_add_pool, global_mean_pool,GATConv,GINConv\n",
    "from torch.nn import Sequential, Linear, ReLU\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_geometric.utils import remove_self_loops\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.data import InMemoryDataset #easily fits into cpu memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_edges_list(root, code):\n",
    "    \"\"\"Read one ligand structure file and return its bonds as a directed edge list.\n",
    "\n",
    "    The ligand is read from ``join(root, code, code + \"_h.sdf\")``.\n",
    "\n",
    "    Args:\n",
    "        root: Base directory of the dataset.\n",
    "        code: Name of the sub-directory and prefix of the ligand file name.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(edge_index, edge_attr)``. ``edge_index`` is an array with two\n",
    "        rows (begin atoms, end atoms) in which every bond is listed once per\n",
    "        direction. ``edge_attr`` holds one bond-type code per edge, in the same\n",
    "        order: 1 = SINGLE, 2 = DOUBLE, 3 = TRIPLE, 4 = any other bond type.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If RDKit returns no molecule for the file.\n",
    "    \"\"\"\n",
    "    ligand_path = join(root, code, code + \"_h.sdf\")\n",
    "    mol = MolFromMolFile(ligand_path)\n",
    "    if mol is None:\n",
    "        # Fail with the offending path instead of an AttributeError on None.\n",
    "        raise ValueError(\"RDKit could not parse ligand file: \" + ligand_path)\n",
    "\n",
    "    # Edge attribute is the bond-type code only; anything that is not\n",
    "    # SINGLE/DOUBLE/TRIPLE (e.g. AROMATIC) is mapped to 4.\n",
    "    bond_codes = {\"SINGLE\": 1, \"DOUBLE\": 2, \"TRIPLE\": 3}\n",
    "    begin_atoms, end_atoms, edge_weights = [], [], []\n",
    "    for bond in mol.GetBonds():\n",
    "        begin_atoms.append(bond.GetBeginAtomIdx())\n",
    "        end_atoms.append(bond.GetEndAtomIdx())\n",
    "        edge_weights.append(bond_codes.get(str(bond.GetBondType()), 4))\n",
    "    edge_features = np.array(edge_weights)\n",
    "\n",
    "    # torch-geometric graphs are directed, so add the reverse direction of every edge.\n",
    "    edge_index = np.array([begin_atoms + end_atoms, end_atoms + begin_atoms])\n",
    "    return edge_index, np.concatenate((edge_features, edge_features), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PDBbindDataset(InMemoryDataset):\n",
    "    \"\"\"Graph dataset assembled from ligand files, node features and activity labels.\n",
    "\n",
    "    Args:\n",
    "        root: Dataset directory; the two file arguments below are joined onto it.\n",
    "        node_features: File (relative to ``root``) loaded with ``dd.io.load``;\n",
    "            it is indexed by the values of the ``PDB`` column.\n",
    "        activity_csv: CSV file (relative to ``root``) with at least the columns\n",
    "            ``PDB`` (complex code) and ``pk`` (regression target).\n",
    "        transform: Forwarded to ``InMemoryDataset``.\n",
    "        pre_transform: Forwarded to ``InMemoryDataset``; also applied in ``process``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root, node_features,activity_csv,transform=None, pre_transform=None):\n",
    "        self.root = root\n",
    "        self.node_features = node_features\n",
    "        self.activity_csv = activity_csv\n",
    "        super(PDBbindDataset, self).__init__(root, transform, pre_transform)\n",
    "        # NOTE(review): newer PyTorch releases default torch.load to weights_only=True,\n",
    "        # which may reject this pickled (data, slices) tuple - confirm for the installed version.\n",
    "        self.data, self.slices = torch.load(self.processed_paths[0])\n",
    "\n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        \"\"\"No raw files are declared.\"\"\"\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        \"\"\"Name of the single processed file written by ``process``.\"\"\"\n",
    "        return [\"data.pt\"]\n",
    "\n",
    "    def download(self):\n",
    "        \"\"\"Nothing to download.\"\"\"\n",
    "        pass\n",
    "\n",
    "    def process(self):\n",
    "        \"\"\"Build one ``Data`` graph per row of the activity table and save the collated result.\"\"\"\n",
    "        self.node_data = dd.io.load(join(self.root, self.node_features))\n",
    "        # Load the CSV with the activity data.\n",
    "        self.activity = pd.read_csv(join(self.root, self.activity_csv))\n",
    "\n",
    "        # Parse every ligand file once and keep both results (edge index and edge\n",
    "        # attributes); previously each file was parsed twice.\n",
    "        self.edge_indexes = {}\n",
    "        self.edge_data = {}\n",
    "        for key in self.activity.PDB:\n",
    "            if key not in self.edge_indexes:\n",
    "                self.edge_indexes[key], self.edge_data[key] = add_edges_list(self.root, key)\n",
    "\n",
    "        # Label lookup: the pk of the first row for each PDB code, built in one pass\n",
    "        # instead of filtering the whole table once per complex.\n",
    "        first_pk = {}\n",
    "        for key, pk in zip(self.activity.PDB, self.activity.pk):\n",
    "            first_pk.setdefault(key, pk)\n",
    "\n",
    "        # Read data into one list of `Data` objects.\n",
    "        data_list = [Data(x = torch.FloatTensor(self.node_data[key]),\n",
    "                          edge_index = torch.LongTensor(self.edge_indexes[key]),\n",
    "                          edge_attr = torch.FloatTensor(self.edge_data[key]),\n",
    "                          y = torch.FloatTensor([first_pk[key]])) for key in self.activity.PDB ]\n",
    "\n",
    "        if self.pre_filter is not None:\n",
    "            data_list = [data for data in data_list if self.pre_filter(data)]\n",
    "\n",
    "        if self.pre_transform is not None:\n",
    "            data_list = [self.pre_transform(data) for data in data_list]\n",
    "\n",
    "        data, slices = self.collate(data_list)\n",
    "        torch.save((data, slices), self.processed_paths[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training/validation graphs: files are resolved relative to the refined-set directory.\n",
    "dataset = PDBbindDataset(\n",
    "    root=\"refined-set\",\n",
    "    node_features=\"refined_set.h5\",\n",
    "    activity_csv=\"index/refined_data2020.csv\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Fuse GCN, GAT and GIN graph encoders into one scalar prediction per graph.\n",
    "\n",
    "    Every encoder receives the same 373-dimensional node features and the same\n",
    "    edge index, applies three conv -> ReLU -> BatchNorm stages and sum-pools the\n",
    "    nodes of each graph. The pooled vectors are concatenated and passed through\n",
    "    a three-layer fully connected head.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # GCN encoder: 373 -> 256 -> 128 -> 128\n",
    "        self.conv1 = GCNConv(373, 256, cached=False)\n",
    "        self.bn01 = BatchNorm1d(256)\n",
    "        self.conv2 = GCNConv(256, 128, cached=False)\n",
    "        self.bn02 = BatchNorm1d(128)\n",
    "        self.conv3 = GCNConv(128, 128, cached=False)\n",
    "        self.bn03 = BatchNorm1d(128)\n",
    "        # GAT encoder, three heads per layer; BatchNorm widths are 3x the per-head size\n",
    "        self.gat1 = GATConv(373, 256, heads=3)\n",
    "        self.bn11 = BatchNorm1d(256 * 3)\n",
    "        self.gat2 = GATConv(256 * 3, 128, heads=3)\n",
    "        self.bn12 = BatchNorm1d(128 * 3)\n",
    "        self.gat3 = GATConv(128 * 3, 128, heads=3)\n",
    "        self.bn13 = BatchNorm1d(128 * 3)\n",
    "        # GIN encoder: each layer wraps its own two-layer MLP\n",
    "        gin_mlp1 = Sequential(Linear(373, 256), ReLU(), Linear(256, 256))\n",
    "        self.gin1 = GINConv(gin_mlp1)\n",
    "        self.bn21 = BatchNorm1d(256)\n",
    "        gin_mlp2 = Sequential(Linear(256, 128), ReLU(), Linear(128, 128))\n",
    "        self.gin2 = GINConv(gin_mlp2)\n",
    "        self.bn22 = BatchNorm1d(128)\n",
    "        gin_mlp3 = Sequential(Linear(128, 64), ReLU(), Linear(64, 64))\n",
    "        self.gin3 = GINConv(gin_mlp3)\n",
    "        self.bn23 = BatchNorm1d(64)\n",
    "        # Fully connected head on the concatenated encoder outputs\n",
    "        self.fc1 = Linear(128 * 4 + 64, 256)\n",
    "        self.dropout1 = Dropout(p=0.2)\n",
    "        self.fc2 = Linear(256, 64)\n",
    "        self.dropout2 = Dropout(p=0.2)\n",
    "        self.fc3 = Linear(64, 1)\n",
    "\n",
    "    def _encode(self, h, edge_index, batch, stages):\n",
    "        \"\"\"Run conv -> ReLU -> BatchNorm for each stage, then sum-pool the nodes of each graph.\"\"\"\n",
    "        for conv, norm in stages:\n",
    "            h = norm(F.relu(conv(h, edge_index)))\n",
    "        return global_add_pool(h, batch)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return a 1-D tensor with one prediction per graph in ``data``.\"\"\"\n",
    "        nodes, edge_index = data.x, data.edge_index\n",
    "\n",
    "        gcn_out = self._encode(nodes, edge_index, data.batch, (\n",
    "            (self.conv1, self.bn01),\n",
    "            (self.conv2, self.bn02),\n",
    "            (self.conv3, self.bn03),\n",
    "        ))\n",
    "        gat_out = self._encode(nodes, edge_index, data.batch, (\n",
    "            (self.gat1, self.bn11),\n",
    "            (self.gat2, self.bn12),\n",
    "            (self.gat3, self.bn13),\n",
    "        ))\n",
    "        gin_out = self._encode(nodes, edge_index, data.batch, (\n",
    "            (self.gin1, self.bn21),\n",
    "            (self.gin2, self.bn22),\n",
    "            (self.gin3, self.bn23),\n",
    "        ))\n",
    "\n",
    "        # Concatenate the three graph-level representations along the feature axis.\n",
    "        fused = torch.cat((gcn_out, gat_out, gin_out), 1)\n",
    "        fused = self.dropout1(F.relu(self.fc1(fused)))\n",
    "        fused = self.dropout2(F.relu(self.fc2(fused)))\n",
    "        # NOTE(review): the final ReLU clamps predictions to >= 0 - confirm this is intended.\n",
    "        return F.relu(self.fc3(fused)).view(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "def train(model, train_loader,epoch,device,optimizer,scheduler):\n",
    "    model.train()\n",
    "    loss_all = 0\n",
    "    error = 0\n",
    "    \n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        loss = F.mse_loss(model(data), data.y)\n",
    "        loss.backward()\n",
    "        loss_all += loss.item() * data.num_graphs\n",
    "        error += (model(data) - data.y).abs().sum().item()  # MAE\n",
    "        torch.nn.utils.clip_grad_value_(model.parameters(), 1)\n",
    "        optimizer.step()\n",
    "    return loss_all / len(train_loader.dataset), error / len(train_loader.dataset)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model, loader,device):\n",
    "    model.eval()\n",
    "    error = 0\n",
    "\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        error += (model(data) - data.y).abs().sum().item()  # MAE\n",
    "    return error / len(loader.dataset)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test_preidstions(model, loader):\n",
    "    model.eval()\n",
    "    pred = []\n",
    "    true = []\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        pred += model(data).detach().cpu().numpy().tolist()\n",
    "        true += data.y.detach().cpu().numpy().tolist()\n",
    "    return pred, true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Read the train/validation id lists for every split strategy;\n",
    "# the context manager closes the file once it has been parsed.\n",
    "with open('train_val_ids.json') as f:\n",
    "    ids = json.load(f)\n",
    "train_ids_random = ids['Random_split_train_ids']\n",
    "val_ids_random = ids['Random_split_val_ids']\n",
    "train_ids_PPS = ids['PPS_train_ids']\n",
    "val_ids_PPS = ids['PPS_val_ids']\n",
    "train_ids_LLS = ids['LLS_train_ids']\n",
    "val_ids_LLS = ids['LLS_val_ids']\n",
    "train_ids_CCS = ids['CCS_train_ids']\n",
    "val_ids_CCS = ids['CCS_val_ids']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partition the dataset by position: indices listed in train_ids_random go to the\n",
    "# training set, every other graph goes to the validation set.\n",
    "# A set gives constant-time membership tests instead of scanning the id list for\n",
    "# every graph (assumes the ids are hashable scalars - TODO confirm).\n",
    "train_index_lookup = set(train_ids_random)\n",
    "train_set = []\n",
    "val_set = []\n",
    "for idx, graph in enumerate(dataset):\n",
    "    if idx in train_index_lookup:\n",
    "        train_set.append(graph)\n",
    "    else:\n",
    "        val_set.append(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "def train(model, train_loader,epoch,device,optimizer,scheduler):\n",
    "    model.train()\n",
    "    loss_all = 0\n",
    "    error = 0\n",
    "    \n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        loss = F.mse_loss(model(data), data.y)\n",
    "        loss.backward()\n",
    "        loss_all += loss.item() * data.num_graphs\n",
    "        error += (model(data) - data.y).abs().sum().item()  # MAE\n",
    "        torch.nn.utils.clip_grad_value_(model.parameters(), 1)\n",
    "        optimizer.step()\n",
    "    return loss_all / len(train_loader.dataset), error / len(train_loader.dataset)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model, loader,device):\n",
    "    model.eval()\n",
    "    error = 0\n",
    "\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        error += (model(data) - data.y).abs().sum().item()  # MAE\n",
    "    return error / len(loader.dataset)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test_predictions(model, loader):\n",
    "    model.eval()\n",
    "    pred = []\n",
    "    true = []\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        pred += model(data).detach().cpu().numpy().tolist()\n",
    "        true += data.y.detach().cpu().numpy().tolist()\n",
    "    return pred, true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/ghaith/opt/anaconda3/envs/pynew/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001, LR: 0.0010000, Loss: 9.2164399, Validation MAE: 1.9831091\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb Cell 10\u001b[0m in \u001b[0;36m<cell line: 18>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=17'>18</a>\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m1\u001b[39m, \u001b[39m1001\u001b[39m):\n\u001b[1;32m     <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=18'>19</a>\u001b[0m     lr \u001b[39m=\u001b[39m scheduler\u001b[39m.\u001b[39moptimizer\u001b[39m.\u001b[39mparam_groups[\u001b[39m0\u001b[39m][\u001b[39m'\u001b[39m\u001b[39mlr\u001b[39m\u001b[39m'\u001b[39m]\n\u001b[0;32m---> <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=19'>20</a>\u001b[0m     loss, train_error \u001b[39m=\u001b[39m train(model, train_loader,epoch,device,optimizer,scheduler)\n\u001b[1;32m     <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=20'>21</a>\u001b[0m     val_error \u001b[39m=\u001b[39m test(model, val_loader,device)\n\u001b[1;32m     <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=21'>22</a>\u001b[0m     train_errors\u001b[39m.\u001b[39mappend(train_error)\n",
      "\u001b[1;32m/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb Cell 10\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, train_loader, epoch, device, optimizer, scheduler)\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m     <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m loss \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mmse_loss(model(data), data\u001b[39m.\u001b[39my)\n\u001b[0;32m---> <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=10'>11</a>\u001b[0m loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m     <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=11'>12</a>\u001b[0m loss_all \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m loss\u001b[39m.\u001b[39mitem() \u001b[39m*\u001b[39m data\u001b[39m.\u001b[39mnum_graphs\n\u001b[1;32m     <a href='vscode-notebook-cell:/Users/ghaith/Desktop/graphLambda/GNNs_fusion_AA_embedding_MC-Output.ipynb#X26sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m error \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m (model(data) \u001b[39m-\u001b[39m data\u001b[39m.\u001b[39my)\u001b[39m.\u001b[39mabs()\u001b[39m.\u001b[39msum()\u001b[39m.\u001b[39mitem()  \u001b[39m# MAE\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/pynew/lib/python3.9/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    478\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m    479\u001b[0m     \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    480\u001b[0m         Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m    481\u001b[0m         (\u001b[39mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    486\u001b[0m         inputs\u001b[39m=\u001b[39minputs,\n\u001b[1;32m    487\u001b[0m     )\n\u001b[0;32m--> 488\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\n\u001b[1;32m    489\u001b[0m     \u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs\n\u001b[1;32m    490\u001b[0m )\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/pynew/lib/python3.9/site-packages/torch/autograd/__init__.py:197\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    192\u001b[0m     retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m    194\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m    195\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    196\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 197\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward(  \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    198\u001b[0m     tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[1;32m    199\u001b[0m     allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Shuffled mini-batches for training; drop_last=True skips the final partial batch.\n",
    "train_loader = DataLoader(\n",
    "    train_set,\n",
    "    batch_size=64, shuffle=True, drop_last=True)\n",
    "\n",
    "# Validate on every sample in a fixed order so the MAE is comparable across epochs.\n",
    "val_loader = DataLoader(\n",
    "    val_set,\n",
    "    batch_size=64, shuffle=False, drop_last=False)\n",
    "best_val_error = None\n",
    "best_model = None\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "train_errors, valid_errors, test_errors = [], [], []\n",
    "model = Net().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',\n",
    "                                factor=0.95, patience=10,\n",
    "                                min_lr=0.00001)\n",
    "for epoch in range(1, 1001):\n",
    "    lr = scheduler.optimizer.param_groups[0]['lr']\n",
    "    loss, train_error = train(model, train_loader, epoch, device, optimizer, scheduler)\n",
    "    val_error = test(model, val_loader, device)\n",
    "    # ReduceLROnPlateau only lowers the learning rate when it is fed the monitored metric.\n",
    "    scheduler.step(val_error)\n",
    "    train_errors.append(train_error)\n",
    "    valid_errors.append(val_error)\n",
    "\n",
    "    # Keep a snapshot of the weights with the lowest validation MAE seen so far.\n",
    "    if best_val_error is None or val_error <= best_val_error:\n",
    "        best_val_error = val_error\n",
    "        best_model = copy.deepcopy(model)\n",
    "\n",
    "    print('Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}'\n",
    "        .format(epoch, lr, loss, val_error))\n",
    "# NOTE(review): test_errors is never appended to in this cell, so this prints 0.\n",
    "print('leng of test errors = ', len(test_errors))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import r2_score, mean_squared_error\n",
    "\n",
    "coreset = PDBbindDataset(\"coreset2016\",\n",
    "                         \"coreset.h5\",\n",
    "                         \"coreset2016.csv\")\n",
    "test_loader = DataLoader(coreset, batch_size=64, shuffle=False)\n",
    "\n",
    "# Evaluate the snapshot with the lowest validation MAE; fall back to the live\n",
    "# model when no snapshot exists in this kernel.\n",
    "eval_model = globals().get(\"best_model\")\n",
    "if eval_model is None:\n",
    "    eval_model = model\n",
    "pred, true = test_predictions(eval_model, test_loader)\n",
    "\n",
    "# Least-squares line true = a * pred + b, drawn over the predicted values\n",
    "# because the x-axis of the plot is the prediction.\n",
    "fit_line = np.poly1d(np.polyfit(pred, true, 1))\n",
    "pred_grid = np.unique(pred)\n",
    "rmse = math.sqrt(mean_squared_error(true, pred))\n",
    "r2 = r2_score(true, pred)\n",
    "\n",
    "plt.plot(pred, true, \"r.\")\n",
    "plt.plot(pred_grid, fit_line(pred_grid))\n",
    "plt.text(9., 4., \"RMSE = \" + \"{:.3f}\".format(rmse))\n",
    "plt.text(9., 6., \"R^2 = \" + \"{:.3f}\".format(r2))\n",
    "plt.xlabel(\"predicted constants\")\n",
    "plt.ylabel(\"true constants\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pynew",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
